{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5.6.2 FETI-DP in NGSolve II: Point-Constraints in 3D\n",
    "\n",
    "We implement standard FETI-DP, using only point-constraints, for Poisson's equation in 3D."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "user_id = 'lukas'\n",
    "num_procs = '40'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from usrmeeting_jupyterstuff import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stop_cluster(user_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "could not start cluster, (probably already/still running)\n",
      "\r",
      "connecting ... try:0 succeeded!"
     ]
    }
   ],
   "source": [
    "start_cluster(num_procs,user_id)\n",
    "connect_cluster(user_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[stdout:16] \n",
      "global,  ndof = 75403 , lodofs = 10355\n",
      "avg DOFs per core:  2260.7\n"
     ]
     }
    ],
   "source": [
    "%%px\n",
    "from ngsolve import *\n",
    "import netgen.meshing as ngmeshing\n",
    "\n",
    "nref = 0\n",
    "\n",
    "dim=3\n",
    "ngmesh = ngmeshing.Mesh(dim=dim)\n",
    "ngmesh.Load('cube.vol')\n",
    "for l in range(nref):\n",
    "    ngmesh.Refine()\n",
    "mesh = Mesh(ngmesh)\n",
    "comm = MPI_Init()\n",
    "fes = H1(mesh, order=2, dirichlet='right|top|left')\n",
    "pardofs = fes.ParallelDofs()\n",
    "a = BilinearForm(fes)\n",
    "u,v = fes.TnT()\n",
    "a += SymbolicBFI(grad(u)*grad(v))\n",
    "a.Assemble()\n",
    "f = LinearForm(fes)\n",
    "f += SymbolicLFI(x*y*v)\n",
    "f.Assemble()\n",
    "avg_dof = comm.Sum(fes.ndof) / comm.size\n",
    "if comm.rank==0:\n",
    "    print('global,  ndof =', fes.ndofglobal, ', lodofs =', fes.lospace.ndofglobal)\n",
    "    print('avg DOFs per core: ', avg_dof)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Finding the primal DOFs\n",
    "\n",
    "In 3 dimensions, we cannot classify the primal DOFs by their multiplicity alone,\n",
    "because DOFs on subdomain edges, like those at subdomain vertices, are shared by more than two subdomains.\n",
    "\n",
    "However, we can very easily find subdomain **faces**: the DOFs shared with one neighbouring rank form a face.\n",
    "Then, we simply define subdomain edges as intersections of faces, and finally subdomain vertices\n",
    "as intersections of edges.\n",
    "\n",
    "We can do this compactly with generator expressions, or in a readable fashion\n",
    "with for-loops.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%px\n",
    "pythonic = True\n",
    "if pythonic:\n",
    "    # one face per neighbouring rank: the free vertex DOFs (d < mesh.nv) shared with it\n",
    "    faces = [set(d for d in pardofs.Proc2Dof(p) if d<mesh.nv and fes.FreeDofs()[d] ) for p in pardofs.ExchangeProcs()]\n",
    "    # edges: pairwise face intersections with more than one DOF; sort each one before deduplicating\n",
    "    edges = sorted(e for e in set(tuple(sorted(f1.intersection(f2))) for f1 in faces for f2 in faces if f1 is not f2) if len(e)>1)\n",
    "    # vertices: DOFs contained in two different edges\n",
    "    vertices = sorted(set([ v for e1 in edges for e2 in edges if e1 is not e2 for v in set(e1).intersection(set(e2)) ]))\n",
    "else:\n",
    "    # the same construction, written with explicit loops\n",
    "    faces = []\n",
    "    for p in pardofs.ExchangeProcs():\n",
    "        faces.append(set(d for d in pardofs.Proc2Dof(p) if d<mesh.nv and fes.FreeDofs()[d]))\n",
    "    edges = []\n",
    "    for f1 in faces:\n",
    "        for f2 in faces:\n",
    "            if f1 is not f2:\n",
    "                edge = tuple(sorted(f1.intersection(f2)))\n",
    "                if len(edge) > 1 and edge not in edges:\n",
    "                    edges.append(edge)\n",
    "    edges = sorted(edges)\n",
    "    vertices = set()\n",
    "    for e1 in edges:\n",
    "        for e2 in edges:\n",
    "            if e1 is not e2:\n",
    "                vertices |= set(e1).intersection(e2)\n",
    "    vertices = sorted(vertices)"
   ]
  },
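  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There is one problem left to consider. In order for a DOF to be included in \"vertices\" on some rank,\n",
    "that rank has to possess **two** edges that contain it. This is not always the case.\n",
    "\n",
    "We have to flag a DOF as a vertex if **any** rank thinks it should be one.\n",
    "To do so, each rank marks its local vertices in a distributed vector; cumulating that vector sums the marks\n",
    "over all ranks sharing a DOF, so every nonzero entry is a primal DOF."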
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[stdout:16] \n",
      "# primal dofs global:  122\n",
      "min, avg, max per rank:  4   12.425   25\n"
     ]
    }
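    ],
   "source": [
    "%%px\n",
    "from ngsolve.la import DISTRIBUTED\n",
    "\n",
    "# mark the locally detected vertices with 1, everything else with 0\n",
    "vec = f.vec.CreateVector()\n",
    "vec.local_vec[:] = 0.0\n",
    "for v in vertices:\n",
    "    vec.local_vec[v] = 1\n",
    "# treat the marks as distributed values and sum them over all sharing ranks\n",
    "vec.SetParallelStatus(DISTRIBUTED)\n",
    "vec.Cumulate()\n",
    "# a DOF is primal if any rank marked it\n",
    "primal_dofs = BitArray([vec.local_vec[k]!=0 for k in range(fes.ndof)]) & fes.FreeDofs()\n",
    "\n",
    "# count every primal DOF once, on the lowest rank that shares it\n",
    "nprim = comm.Sum(sum([1 for k in range(fes.ndof) if primal_dofs[k] and comm.rank<pardofs.Dof2Proc(k)[0] ]))\n",
    "# rank 0 contributes nprim instead of its own count, which keeps it out of the minimum\n",
    "npmin = comm.Min(primal_dofs.NumSet() if comm.rank else nprim)\n",
    "npavg = comm.Sum(primal_dofs.NumSet())/comm.size\n",
    "npmax = comm.Max(primal_dofs.NumSet())\n",
    "if comm.rank==0:\n",
    "    print('# primal dofs global: ', nprim)\n",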
    "    print('min, avg, max per rank: ', npmin, ' ', npavg, ' ', npmax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "stop_cluster(user_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Everything else works the same. Experiment below!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Waiting for connection file: ~/.ipython/profile_ngsolve/security/ipcontroller-koglerlukas-client.json\n",
      "connecting ... try:6 succeeded!"
     ]
    }
   ],
   "source": [
    "from usrmeeting_jupyterstuff import *\n",
    "user_id = 'lukas'\n",
    "num_procs = '100'\n",
    "stop_cluster(user_id)\n",
    "start_cluster(num_procs, user_id)\n",
    "connect_cluster(user_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%px\n",
    "from ngsolve import *\n",
    "import netgen.meshing as ngmeshing\n",
    "from ngsolve.la import ParallelMatrix, FETI_Jump\n",
    "from dd_toolbox import FindFEV, LocGlobInverse, ScaledMat\n",
    "\n",
    "def load_mesh(nref=0):\n",
    "    ngmesh = ngmeshing.Mesh(dim=3)\n",
    "    ngmesh.Load('cube.vol')\n",
    "    for l in range(nref):\n",
    "        ngmesh.Refine()\n",
    "    return Mesh(ngmesh)\n",
    "\n",
    "def setup_space(mesh, order=1):\n",
    "    comm = MPI_Init()\n",
    "    fes = H1(mesh, order=order, dirichlet='right|top')\n",
    "    a = BilinearForm(fes)\n",
    "    u,v = fes.TnT()\n",
    "    a += SymbolicBFI(grad(u)*grad(v))\n",
    "    a.Assemble()\n",
    "    f = LinearForm(fes)\n",
    "    f += SymbolicLFI(x*y*v)\n",
    "    f.Assemble()\n",
    "    avg_dof = comm.Sum(fes.ndof) / comm.size\n",
    "    if comm.rank==0:\n",
    "        print('global,  ndof =', fes.ndofglobal, ', lodofs =', fes.lospace.ndofglobal)\n",
    "        print('avg DOFs per core: ', avg_dof)\n",
    "    return [fes, a, f]\n",
    "\n",
    "def setup_FETIDP(fes, a):\n",
    "    # NOTE: relies on the globals mesh and comm set in the calling cell\n",
    "    faces, edges, vertices = FindFEV(mesh.dim, mesh.nv,\n",
    "                                     fes.ParallelDofs(), fes.FreeDofs())\n",
    "    # primal DOFs: the free subdomain vertices (build the lookup set only once)\n",
    "    vset = set(vertices)\n",
    "    primal_dofs = BitArray([ v in vset for v in range(fes.ndof) ]) & fes.FreeDofs()\n",
    "    dp_pardofs = fes.ParallelDofs().SubSet(primal_dofs)\n",
    "    # count every primal DOF once, on the lowest rank that shares it\n",
    "    nprim = comm.Sum(sum([1 for k in range(fes.ndof) if primal_dofs[k] and comm.rank<fes.ParallelDofs().Dof2Proc(k)[0] ]))\n",
    "    if comm.rank==0:\n",
    "        print('# of global primal dofs: ', nprim)\n",
    "    # local matrices, coupled across ranks only through the primal DOFs\n",
    "    A_dp = ParallelMatrix(a.mat.local_mat, dp_pardofs)\n",
    "    # jump operator on the dual DOFs (free, shared, not primal)\n",
    "    dual_pardofs = fes.ParallelDofs().SubSet(BitArray(~primal_dofs & fes.FreeDofs()))\n",
    "    B = FETI_Jump(dual_pardofs, u_pardofs=dp_pardofs)\n",
    "    if comm.rank==0:\n",
    "        print('# of global multipliers = :', B.col_pardofs.ndofglobal)\n",
    "    A_dp_inv = LocGlobInverse(A_dp, fes.FreeDofs(),\n",
    "                              invtype_loc='sparsecholesky',\n",
    "                              invtype_glob='masterinverse')\n",
    "    F = B @ A_dp_inv @ B.T\n",
    "    # Dirichlet preconditioner: without (Fhat) and with (Fhat2) multiplicity scaling\n",
    "    innerdofs = BitArray([len(fes.ParallelDofs().Dof2Proc(k))==0 for k in range(fes.ndof)]) & fes.FreeDofs()\n",
    "    A = a.mat.local_mat\n",
    "    Aiinv = A.Inverse(innerdofs, inverse='sparsecholesky')\n",
    "    scaledA = ScaledMat(A, [1.0/(1+len(fes.ParallelDofs().Dof2Proc(k))) for k in range(fes.ndof)])\n",
    "    scaledBT = ScaledMat(B.T, [1.0/(1+len(fes.ParallelDofs().Dof2Proc(k))) for k in range(fes.ndof)])\n",
    "    Fhat = B @ A @ (IdentityMatrix() - Aiinv @ A) @ B.T\n",
    "    Fhat2 = B @ scaledA @ (IdentityMatrix() - Aiinv @ A) @ scaledBT\n",
    "    return [A_dp, A_dp_inv, F, Fhat, Fhat2, B, scaledA, scaledBT]\n",
    "\n",
    "def prep(B, Ainv, f):\n",
    "    # NOTE: writes into the global vector rhs created in the calling cell\n",
    "    rhs.data = (B @ Ainv) * f.vec\n",
    "    return rhs\n",
    "\n",
    "def solve(mat, pre, rhs, sol):\n",
    "    # run preconditioned CG and return the elapsed wall time\n",
    "    t = comm.WTime()\n",
    "    solvers.CG(mat=mat, pre=pre, rhs=rhs, sol=sol,\n",
    "               maxsteps=100, printrates=comm.rank==0, tol=1e-6)\n",
    "    return comm.WTime() - t\n",
    "\n",
    "def post(B, Ainv, gfu, lam):\n",
    "    # NOTE: reads the global linear form f from the calling cell\n",
    "    hv = B.CreateRowVector()\n",
    "    hv.data = f.vec - B.T * lam\n",
    "    gfu.vec.data = Ainv * hv\n",
    "    # the remaining jump of u across subdomain interfaces should be small\n",
    "    jump = lam.CreateVector()\n",
    "    jump.data = B * gfu.vec\n",
    "    norm_jump = Norm(jump)\n",
    "    if comm.rank==0:\n",
    "        print('\\nnorm jump u: ', norm_jump)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[stdout:26] \n",
      "global,  ndof = 576877 , lodofs = 75403\n",
      "avg DOFs per core:  6607.56\n",
      "# of global primal dofs:  399\n",
      "# of global multipliers = : 87867\n",
      "\n",
      "Without multiplicity scaling:\n",
      "it =  0  err =  1.3828932193544856\n",
      "it =  1  err =  0.3093351749153988\n",
      "it =  2  err =  0.18687127840143788\n",
      "it =  3  err =  0.15982668754855409\n",
      "it =  4  err =  0.10617262264774008\n",
      "it =  5  err =  0.08636099359688607\n",
      "it =  6  err =  0.06534050969841254\n",
      "it =  7  err =  0.04057663450962983\n",
      "it =  8  err =  0.029795216503713105\n",
      "it =  9  err =  0.019191776032195372\n",
      "it =  10  err =  0.013409312593594812\n",
      "it =  11  err =  0.010581447400156284\n",
      "it =  12  err =  0.008496811350847722\n",
      "it =  13  err =  0.005343556967304746\n",
      "it =  14  err =  0.004037091662300088\n",
      "it =  15  err =  0.0031488610707263364\n",
      "it =  16  err =  0.001955640666077683\n",
      "it =  17  err =  0.0013252053330392102\n",
      "it =  18  err =  0.0008350370464745804\n",
      "it =  19  err =  0.0005929073126342288\n",
      "it =  20  err =  0.0004345439005536961\n",
      "it =  21  err =  0.00030177509231099235\n",
      "it =  22  err =  0.000235125851517741\n",
      "it =  23  err =  0.0001333039960992988\n",
      "it =  24  err =  0.00010897281470394863\n",
      "it =  25  err =  5.7947863178417394e-05\n",
      "it =  26  err =  5.0436899032635396e-05\n",
      "it =  27  err =  2.833460785216941e-05\n",
      "it =  28  err =  2.4503685663038197e-05\n",
      "it =  29  err =  1.5454634417623603e-05\n",
      "it =  30  err =  1.1090913251083186e-05\n",
      "it =  31  err =  7.673508671528517e-06\n",
      "it =  32  err =  4.6991810256279215e-06\n",
      "it =  33  err =  3.5467908187332595e-06\n",
      "it =  34  err =  2.124001011194309e-06\n",
      "it =  35  err =  1.7687264029974105e-06\n",
      "time to solve:  0.6880253719864413\n",
      "\n",
      "With multiplicity scaling:\n",
      "it =  0  err =  0.5664618015376661\n",
      "it =  1  err =  0.1224428346506882\n",
      "it =  2  err =  0.06781240811871787\n",
      "it =  3  err =  0.05441728733297438\n",
      "it =  4  err =  0.030540989185688594\n",
      "it =  5  err =  0.021651861416661526\n",
      "it =  6  err =  0.015835176581462645\n",
      "it =  7  err =  0.009189106784938087\n",
      "it =  8  err =  0.00665121993521837\n",
      "it =  9  err =  0.004076599182832595\n",
      "it =  10  err =  0.0025245723947087415\n",
      "it =  11  err =  0.001596301148038511\n",
      "it =  12  err =  0.0012419381581669539\n",
      "it =  13  err =  0.0008223671367012567\n",
      "it =  14  err =  0.00041055052407066705\n",
      "it =  15  err =  0.00030352885324448577\n",
      "it =  16  err =  0.00021292551429473027\n",
      "it =  17  err =  0.00010535909369377514\n",
      "it =  18  err =  7.858092529530714e-05\n",
      "it =  19  err =  4.1265575251836e-05\n",
      "it =  20  err =  2.8532679697126992e-05\n",
      "it =  21  err =  1.6665576448624424e-05\n",
      "it =  22  err =  9.268060980115929e-06\n",
      "it =  23  err =  7.031947818198997e-06\n",
      "it =  24  err =  3.1407801205245272e-06\n",
      "it =  25  err =  2.5144714637266164e-06\n",
      "it =  26  err =  1.287853499441041e-06\n",
      "it =  27  err =  8.333768290606212e-07\n",
      "\n",
      "time solve without scaling:  0.6880253719864413\n",
      "time solve with scaling:  0.7163726490689442\n",
      "\n",
      "norm jump u:  9.990814869240256e-06\n"
     ]
    }
   ],
   "source": [
    "%%px\n",
    "comm = MPI_Init()\n",
    "mesh = load_mesh(nref=1)\n",
    "fes, a, f = setup_space(mesh, order=2)\n",
    "A_dp, A_dp_inv, F, Fhat, Fhat2, B, scaledA, scaledBT = setup_FETIDP(fes, a)\n",
    "rhs = B.CreateColVector()\n",
    "lam = B.CreateColVector()\n",
    "prep(B, A_dp_inv, f)\n",
    "if comm.rank==0:\n",
    "    print('\\nWithout multiplicity scaling:')\n",
    "t1 = solve(F,  Fhat,  rhs, lam)\n",
    "if comm.rank==0:\n",
    "    print('time to solve: ', t1)\n",
    "    print('\\nWith multiplicity scaling:')\n",
    "t2 = solve(F,  Fhat2,  rhs, lam)\n",
    "if comm.rank==0:\n",
    "    print('\\ntime solve without scaling: ', t1)\n",
    "    print('time solve with scaling: ', t2)\n",
    "gfu = GridFunction(fes)\n",
    "post(B, A_dp_inv, gfu, lam)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SparseCholesky<d,d,d>::MultAdd :   0.8205492496490479\n"
     ]
    }
   ],
   "source": [
    "%%px --target 1\n",
    "for t in sorted(filter(lambda t:t['time']>0.5, Timers()), key=lambda t:t['time'], reverse=True):\n",
    "    print(t['name'], ':  ', t['time'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[stdout:31] \n",
      "timers from rank  13 :\n",
      "SparseCholesky<d,d,d>::MultAdd :   1.209223747253418\n",
      "SparseCholesky<d,d,d>::MultAdd fac1 :   0.7094001770019531\n",
      "SparseCholesky<d,d,d>::MultAdd fac2 :   0.4826033115386963\n",
      "SparseCholesky - total :   0.41616296768188477\n"
     ]
    }
    ],
   "source": [
    "%%px\n",
    "t_chol = next(filter(lambda t: t['name'] == 'SparseCholesky<d,d,d>::MultAdd', Timers()))\n",
    "maxt = comm.Max(t_chol['time'])\n",
    "if t_chol['time'] == maxt:\n",
    "    print('timers from rank ', comm.rank, ':')\n",
    "    for t in sorted(filter(lambda t:t['time']>min(0.3*maxt, 0.5), Timers()), key=lambda t:t['time'], reverse=True):\n",
    "        print(t['name'], ':  ', t['time'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "stop_cluster(user_id)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
